{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ECE26A1028E24B3099AD6F7D0CECEEE1",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
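    "# Machine Translation and Datasets\n",
    "\n",
    "\n",
    "Machine translation (MT) is the automatic translation of a piece of text from one language into another. Solving this problem with neural networks is usually called neural machine translation (NMT).\n",
    "\n",
    "Key characteristics: the output is a sequence of words rather than a single word, and the length of the output sequence may differ from the length of the source sequence."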
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "id": "431E9AECF60849679FA0CDB3B39C25D5",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['fraeng6506', 'd2l9528', 'd2l6239']"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "os.listdir('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "F99F306F631B4E2F8F3464EEE0234CE1",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/home/kesci/input/d2l9528/')\n",
    "import collections\n",
    "import d2l\n",
    "import zipfile\n",
    "from d2l.data.base import Vocab\n",
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils import data\n",
    "from torch import optim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5DF1C65FE8E3493F83CCAD6D9CAD5E4A",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### Data Preprocessing\n",
    "Clean the dataset and convert it into minibatches that can be fed to the neural network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "C4C3229D46004ECF93F03C6BBC48AEDC",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Go.\tVa !\tCC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #1158250 (Wittydev)\n",
      "Hi.\tSalut !\tCC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #509819 (Aiji)\n",
      "Hi.\tSalut.\tCC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #4320462 (gillux)\n",
      "Run!\tCours !\tCC-BY 2.0 (France) Attribution: tatoeba.org #906328 (papabear) & #906331 (sacredceltic)\n",
      "Run!\tCourez !\tCC-BY 2.0 (France) Attribution: tatoeba.org #906328 (papabear) & #906332 (sacredceltic)\n",
      "Who?\tQui ?\tCC-BY 2.0 (France) Attribution: tatoeba.org #2083030 (CK) & #4366796 (gillux)\n",
      "Wow!\tÇa alors !\tCC-BY 2.0 (France) Attribution: tatoeba.org #52027 (Zifre) & #374631 (zmoo)\n",
      "Fire!\tAu feu !\tCC-BY 2.0 (France) Attribution: tatoeba.org #1829639 (Spamster) & #4627939 (sacredceltic)\n",
      "Help!\tÀ l'aide !\tCC-BY 2.0 (France) Attribution: tatoeba.org #435084 (lukaszpp) & #128430 (sysko)\n",
      "Jump.\tSaute.\tCC-BY 2.0 (France) Attribution: tatoeba.org #631038 (Shishir) & #2416938 (Phoenix)\n",
      "Stop!\tÇa suffit !\tCC-BY 2.0 (France) Attribution: tato\n"
     ]
    }
   ],
   "source": [
    "# Read the raw tab-separated English-French sentence pairs as UTF-8 text.\n",
    "with open('/home/kesci/input/fraeng6506/fra.txt', 'r', encoding='utf-8') as f:\n",
    "    raw_text = f.read()\n",
    "print(raw_text[0:1000])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "BB101036A28849798E95206AAE93FDFC",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "go .\tva !\tcc-by 2 .0 (france) attribution: tatoeba .org #2877272 (cm) & #1158250 (wittydev)\n",
      "hi .\tsalut !\tcc-by 2 .0 (france) attribution: tatoeba .org #538123 (cm) & #509819 (aiji)\n",
      "hi .\tsalut .\tcc-by 2 .0 (france) attribution: tatoeba .org #538123 (cm) & #4320462 (gillux)\n",
      "run !\tcours !\tcc-by 2 .0 (france) attribution: tatoeba .org #906328 (papabear) & #906331 (sacredceltic)\n",
      "run !\tcourez !\tcc-by 2 .0 (france) attribution: tatoeba .org #906328 (papabear) & #906332 (sacredceltic)\n",
      "who?\tqui ?\tcc-by 2 .0 (france) attribution: tatoeba .org #2083030 (ck) & #4366796 (gillux)\n",
      "wow !\tça alors !\tcc-by 2 .0 (france) attribution: tatoeba .org #52027 (zifre) & #374631 (zmoo)\n",
      "fire !\tau feu !\tcc-by 2 .0 (france) attribution: tatoeba .org #1829639 (spamster) & #4627939 (sacredceltic)\n",
      "help !\tà l'aide !\tcc-by 2 .0 (france) attribution: tatoeba .org #435084 (lukaszpp) & #128430 (sysko)\n",
      "jump .\tsaute .\tcc-by 2 .0 (france) attribution: tatoeba .org #631038 (shishir) & #2416938 (phoenix)\n",
      "stop !\tça suffit !\tcc-b\n"
     ]
    }
   ],
   "source": [
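    "def preprocess_raw(text):\n",
    "    # Replace the narrow no-break space and the non-breaking space with ordinary spaces,\n",
    "    # and lowercase once so that every index refers to the same string.\n",
    "    text = text.replace('\\u202f', ' ').replace('\\xa0', ' ').lower()\n",
    "    out = []\n",
    "    for i, char in enumerate(text):\n",
    "        # Insert a space before ',', '!' or '.' when none precedes it.\n",
    "        # '?' is not in this set, so 'who?' stays unsplit in the printed sample.\n",
    "        if char in (',', '!', '.') and i > 0 and text[i-1] != ' ':\n",
    "            out.append(' ')\n",
    "        out.append(char)\n",
    "    return ''.join(out)\n",
    "\n",
    "text = preprocess_raw(raw_text)\n",
    "print(text[0:1000])"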
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6F249A8DDE8E41EA8BF5DF0C62196D30",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
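    "Characters are stored in a computer as encoded values. The ordinary space is \\x20, which lies within the printable ASCII range 0x20~0x7e.\n",
    "By contrast, \\xa0 belongs to the extended character set of Latin-1 (ISO/IEC 8859-1) and represents the non-breaking space, nbsp. It lies outside the GBK encoding range and is a special character that has to be cleaned out, which is why `preprocess_raw` replaces it with an ordinary space. Cleaning the data in this way is the first step of preprocessing."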
   ]
  },
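  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "A minimal sketch of why this cleaning matters, using an invented sample string rather than a line from the dataset. The non-breaking space and the narrow no-break space (`\\u202f`, also replaced in `preprocess_raw`) render like a space but are different code points, so splitting on `' '` does not separate the tokens until they are replaced. The names `NBSP`, `NARROW_NBSP` and `sample` are illustrative only.\n",
    "\n",
    "```python\n",
    "# Invented sample: 'salut' + NBSP + '!' and 'qui' + narrow NBSP + '?'\n",
    "NBSP = chr(0xa0)           # non-breaking space (Latin-1)\n",
    "NARROW_NBSP = chr(0x202f)  # narrow no-break space\n",
    "\n",
    "sample = 'salut' + NBSP + '!' + ' ' + 'qui' + NARROW_NBSP + '?'\n",
    "\n",
    "# Splitting on an ordinary space leaves the special spaces inside the tokens.\n",
    "before = sample.split(' ')\n",
    "\n",
    "# Replace both special spaces with an ordinary space, then split again.\n",
    "cleaned = sample.replace(NARROW_NBSP, ' ').replace(NBSP, ' ')\n",
    "after = cleaned.split(' ')\n",
    "\n",
    "print(len(before), len(after))\n",
    "```\n",
    "\n",
    "With this invented sample, `before` holds 2 tokens and `after` holds 4."
   ]
  },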
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3F005D61B217439D988A086B843DF940",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### Tokenization\n",
    "Convert each string into a list of words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "AC2D82A7081943B0BE94018061BC62AF",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([['go', '.'], ['hi', '.'], ['hi', '.']],\n",
       " [['va', '!'], ['salut', '!'], ['salut', '.']])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_examples = 50000\n",
    "source, target = [], []\n",
    "for i, line in enumerate(text.split('\\n')):\n",
    "    # Stops once i exceeds num_examples, so lines 0..num_examples are read.\n",
    "    if i > num_examples:\n",
    "        break\n",
    "    # Fields are tab-separated: English, French, then the attribution.\n",
    "    parts = line.split('\\t')\n",
    "    if len(parts) >= 2:\n",
    "        source.append(parts[0].split(' '))\n",
    "        target.append(parts[1].split(' '))\n",
    "\n",
    "source[0:3], target[0:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "7589E7D345B3463A8F0F4574ED6EDA9A",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://cdn.kesci.com/rt_upload/7589E7D345B3463A8F0F4574ED6EDA9A/q5jefa8ffq.svg\">"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.set_figsize()\n",
    "d2l.plt.hist([[len(l) for l in source], [len(l) for l in target]],label=['source', 'target'])\n",
    "d2l.plt.legend(loc='upper right');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "A9ABB7E025F14892BD21C7ADA88F848C",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### Building the Vocabulary\n",
    "Map each list of words to a list of word ids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "D432CD6C4A644196831393AEF3EF6A06",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3789"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def build_vocab(tokens):\n",
    "    tokens = [token for line in tokens for token in line]\n",
    "    return d2l.data.base.Vocab(tokens, min_freq=3, use_special_tokens=True)\n",
    "\n",
    "src_vocab = build_vocab(source)\n",
    "len(src_vocab)"
   ]
  },
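  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "The `Vocab` class used above is imported from the d2l package, and its code is not shown in this notebook. The following is only a sketch of the general idea, not the d2l implementation: `TinyVocab` is a hypothetical name, and the id layout (`<pad>`=0, `<bos>`=1, `<eos>`=2, `<unk>`=3) and the `freq >= min_freq` cutoff are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import collections\n",
    "\n",
    "class TinyVocab:\n",
    "    # Hypothetical stand-in for a vocabulary class; not the d2l implementation.\n",
    "    def __init__(self, tokens, min_freq=0):\n",
    "        counter = collections.Counter(tokens)\n",
    "        # Assumed layout: the special tokens take the first four ids.\n",
    "        self.pad, self.bos, self.eos, self.unk = 0, 1, 2, 3\n",
    "        self.idx_to_token = ['<pad>', '<bos>', '<eos>', '<unk>']\n",
    "        # Most frequent tokens first; ties are broken alphabetically.\n",
    "        for token, freq in sorted(counter.items(), key=lambda kv: (-kv[1], kv[0])):\n",
    "            if freq >= min_freq:\n",
    "                self.idx_to_token.append(token)\n",
    "        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.idx_to_token)\n",
    "\n",
    "    def __getitem__(self, tokens):\n",
    "        # A list of tokens maps to a list of ids; unseen tokens map to <unk>.\n",
    "        if isinstance(tokens, (list, tuple)):\n",
    "            return [self[t] for t in tokens]\n",
    "        return self.token_to_idx.get(tokens, self.unk)\n",
    "\n",
    "lines = [['go', '.'], ['hi', '.'], ['go', 'now', '.']]\n",
    "vocab = TinyVocab([tok for line in lines for tok in line], min_freq=2)\n",
    "ids = vocab[['go', 'now', '.']]\n",
    "print(ids, len(vocab))\n",
    "```\n",
    "\n",
    "Here `ids` is `[5, 3, 4]`: `now` occurs only once, falls below `min_freq=2`, and therefore maps to the assumed `<unk>` id."
   ]
  },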
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "B1461B583FF94E978D53F05EB8EFCFB5",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "\n",
    "![Image Name](https://cdn.kesci.com/upload/image/q5jc5ga5gy.png?imageView2/0/w/960/h/960)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FF4EFEF2AB424A5587358006DE570FAA",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### Loading the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "E556507EFE624C93B9A181F68B2CF64A",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[38, 4, 0, 0, 0, 0, 0, 0, 0, 0]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def pad(line, max_len, padding_token):\n",
    "    if len(line) > max_len:\n",
    "        return line[:max_len]\n",
    "    return line + [padding_token] * (max_len - len(line))\n",
    "pad(src_vocab[source[0]], 10, src_vocab.pad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "E704BB99768F474E81871942B81E0665",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def build_array(lines, vocab, max_len, is_source):\n",
    "    lines = [vocab[line] for line in lines]\n",
    "    if not is_source:\n",
    "        lines = [[vocab.bos] + line + [vocab.eos] for line in lines]\n",
    "    array = torch.tensor([pad(line, max_len, vocab.pad) for line in lines])\n",
    "    valid_len = (array != vocab.pad).sum(1)  # count non-padding tokens along dim 1 (the sequence axis)\n",
    "    return array, valid_len"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DF66F496328148489E11D0D826291CB6",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "\n",
    "![Image Name](https://cdn.kesci.com/upload/image/q5jc6e5tt1.png?imageView2/0/w/960/h/960)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "7BB02BF246AC4302824AA6CFA95A40D8",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def load_data_nmt(batch_size, max_len): # This function is saved in d2l.\n",
    "    src_vocab, tgt_vocab = build_vocab(source), build_vocab(target)\n",
    "    src_array, src_valid_len = build_array(source, src_vocab, max_len, True)\n",
    "    tgt_array, tgt_valid_len = build_array(target, tgt_vocab, max_len, False)\n",
    "    train_data = data.TensorDataset(src_array, src_valid_len, tgt_array, tgt_valid_len)\n",
    "    train_iter = data.DataLoader(train_data, batch_size, shuffle=True)\n",
    "    return src_vocab, tgt_vocab, train_iter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "DF377DA83D8C409CBC7A48ABDC3CCFDB",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X = tensor([[   5,   24,    3,    4,    0,    0,    0,    0],\n",
      "        [  12, 1388,    7,    3,    4,    0,    0,    0]], dtype=torch.int32) \n",
      "Valid lengths for X = tensor([4, 5]) \n",
      "Y = tensor([[   1,   23,   46,    3,    3,    4,    2,    0],\n",
      "        [   1,   15,  137,   27, 4736,    4,    2,    0]], dtype=torch.int32) \n",
      "Valid lengths for Y = tensor([7, 7])\n"
     ]
    }
   ],
   "source": [
    "src_vocab, tgt_vocab, train_iter = load_data_nmt(batch_size=2, max_len=8)\n",
    "for X, X_valid_len, Y, Y_valid_len, in train_iter:\n",
    "    print('X =', X.type(torch.int32), '\\nValid lengths for X =', X_valid_len,\n",
    "        '\\nY =', Y.type(torch.int32), '\\nValid lengths for Y =', Y_valid_len)\n",
    "    break"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "A2C40BD7BF1942A48EFDCE97B0C55D30",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "# Encoder-Decoder\n",
    " Encoder: maps the input to a hidden state  \n",
    " Decoder: maps the hidden state to the output\n",
    "\n",
    "\n",
    "![Image Name](https://cdn.kesci.com/upload/image/q5jcat3c8m.png?imageView2/0/w/640/h/640)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "86647F7276B743F1B4D466AF8055C27A",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(Encoder, self).__init__(**kwargs)\n",
    "\n",
    "    def forward(self, X, *args):\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "A894858F5FA94C3599B3CE9D4EF88964",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(Decoder, self).__init__(**kwargs)\n",
    "\n",
    "    def init_state(self, enc_outputs, *args):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def forward(self, X, state):\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "26FFA2FB4ED04DC6BA3BAAA4A522217C",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class EncoderDecoder(nn.Module):\n",
    "    def __init__(self, encoder, decoder, **kwargs):\n",
    "        super(EncoderDecoder, self).__init__(**kwargs)\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def forward(self, enc_X, dec_X, *args):\n",
    "        enc_outputs = self.encoder(enc_X, *args)\n",
    "        dec_state = self.decoder.init_state(enc_outputs, *args)\n",
    "        return self.decoder(dec_X, dec_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2296785759574D6C86BD5DFBBC0A2E90",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "This architecture can also be applied to dialogue systems and other generative tasks."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "857C6A5FDBA64866B66A42F8FD1A6A88",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "# Sequence to Sequence Model\n",
    "\n",
    "### Model\n",
    "Training  \n",
    "![Image Name](https://cdn.kesci.com/upload/image/q5jc7a53pt.png?imageView2/0/w/640/h/640)\n",
    "\n",
    "Prediction  \n",
    "![Image Name](https://cdn.kesci.com/upload/image/q5jcecxcba.png?imageView2/0/w/640/h/640)\n",
    "\n",
    "### Detailed Structure\n",
    "![Image Name](https://cdn.kesci.com/upload/image/q5jccjhkii.png?imageView2/0/w/500/h/500)\n",
    "\n",
    "### Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "7886803BFCBC4D85A79A20B43FC5711A",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Seq2SeqEncoder(d2l.Encoder):\n",
    "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
    "                 dropout=0, **kwargs):\n",
    "        super(Seq2SeqEncoder, self).__init__(**kwargs)\n",
    "        self.num_hiddens=num_hiddens\n",
    "        self.num_layers=num_layers\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        self.rnn = nn.LSTM(embed_size,num_hiddens, num_layers, dropout=dropout)\n",
    "   \n",
    "    def begin_state(self, batch_size, device):\n",
    "        return [torch.zeros(size=(self.num_layers, batch_size, self.num_hiddens),  device=device),\n",
    "                torch.zeros(size=(self.num_layers, batch_size, self.num_hiddens),  device=device)]\n",
    "    def forward(self, X, *args):\n",
    "        X = self.embedding(X)  # X shape: (batch_size, seq_len, embed_size)\n",
    "        X = X.transpose(0, 1)  # nn.LSTM expects time as the first axis: (seq_len, batch_size, embed_size)\n",
    "        # No initial state is passed, so nn.LSTM starts from zeros (begin_state is not needed here).\n",
    "        out, state = self.rnn(X)\n",
    "        # out has shape (seq_len, batch_size, num_hiddens).\n",
    "        # state is a tuple (hidden state, memory cell) for the last time step;\n",
    "        # each has shape (num_layers, batch_size, num_hiddens).\n",
    "        return out, state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "F48A7CA0591641A5A0AC12882EA3DA88",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([7, 4, 16]), 2, torch.Size([2, 4, 16]), torch.Size([2, 4, 16]))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8,num_hiddens=16, num_layers=2)\n",
    "X = torch.zeros((4, 7),dtype=torch.long)\n",
    "output, state = encoder(X)\n",
    "output.shape, len(state), state[0].shape, state[1].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "C9CE0E8FBC8A4F7389B173DD62206CD5",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "75A71227061547E0AF5EA9A5EDAE77F3",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Seq2SeqDecoder(d2l.Decoder):\n",
    "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
    "                 dropout=0, **kwargs):\n",
    "        super(Seq2SeqDecoder, self).__init__(**kwargs)\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        self.rnn = nn.LSTM(embed_size,num_hiddens, num_layers, dropout=dropout)\n",
    "        self.dense = nn.Linear(num_hiddens,vocab_size)\n",
    "\n",
    "    def init_state(self, enc_outputs, *args):\n",
    "        return enc_outputs[1]\n",
    "\n",
    "    def forward(self, X, state):\n",
    "        X = self.embedding(X).transpose(0, 1)\n",
    "        out, state = self.rnn(X, state)\n",
    "        # Move the batch axis back to the front to simplify the loss computation.\n",
    "        out = self.dense(out).transpose(0, 1)\n",
    "        return out, state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "8B922C6FCF5643CFA46ECB509A109C8B",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4, 7, 10]), 2, torch.Size([2, 4, 16]), torch.Size([2, 4, 16]))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8,num_hiddens=16, num_layers=2)\n",
    "state = decoder.init_state(encoder(X))\n",
    "out, state = decoder(X, state)\n",
    "out.shape, len(state), state[0].shape, state[1].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4E7FF047303445E881C9C7617FC845A9",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "28493F1E5E1B47CA9FD60AF2AF1E2491",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
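    "def SequenceMask(X, X_len, value=0):\n",
    "    # Overwrites, in place, every position at or beyond each row's valid length with `value`.\n",
    "    maxlen = X.size(1)\n",
    "    # mask[i, j] is True when j < X_len[i]\n",
    "    mask = torch.arange(maxlen, device=X_len.device)[None, :] < X_len[:, None]\n",
    "    X[~mask] = value\n",
    "    return X"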
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "id": "EAAF6E804ADA42A8892B4445F503E004",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 0, 0],\n",
       "        [4, 5, 0]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.tensor([[1,2,3], [4,5,6]])\n",
    "SequenceMask(X,torch.tensor([1,2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "id": "8AA5E43A0A8F4FD38C6E5AF16A747C6C",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.,  1.,  1.,  1.],\n",
       "         [-1., -1., -1., -1.],\n",
       "         [-1., -1., -1., -1.]],\n",
       "\n",
       "        [[ 1.,  1.,  1.,  1.],\n",
       "         [ 1.,  1.,  1.,  1.],\n",
       "         [-1., -1., -1., -1.]]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.ones((2,3, 4))\n",
    "SequenceMask(X, torch.tensor([1,2]),value=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "id": "2FA64864BA224A49B54C188323F7044B",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):\n",
    "    # pred shape: (batch_size, seq_len, vocab_size)\n",
    "    # label shape: (batch_size, seq_len)\n",
    "    # valid_length shape: (batch_size, )\n",
    "    def forward(self, pred, label, valid_length):\n",
    "        # the sample weights shape should be (batch_size, seq_len)\n",
    "        weights = torch.ones_like(label)\n",
    "        weights = SequenceMask(weights, valid_length).float()\n",
    "        self.reduction='none'\n",
    "        output=super(MaskedSoftmaxCELoss, self).forward(pred.transpose(1,2), label)\n",
    "        return (output*weights).mean(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "id": "F59552886EC2485E8567BCD3C33F6D06",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2.3026, 1.7269, 0.0000])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss = MaskedSoftmaxCELoss()\n",
    "loss(torch.ones((3, 4, 10)), torch.ones((3,4),dtype=torch.long), torch.tensor([4,3,0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "E8BFAA3559C240B6AFF8D4C053CCEE44",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "id": "158F3AF8015D44F6891C579EAE3AC2F6",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_ch7(model, data_iter, lr, num_epochs, device):  # Saved in d2l\n",
    "    model.to(device)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "    loss = MaskedSoftmaxCELoss()\n",
    "    tic = time.time()\n",
    "    for epoch in range(1, num_epochs+1):\n",
    "        l_sum, num_tokens_sum = 0.0, 0.0\n",
    "        for batch in data_iter:\n",
    "            optimizer.zero_grad()\n",
    "            X, X_vlen, Y, Y_vlen = [x.to(device) for x in batch]\n",
    "            Y_input, Y_label, Y_vlen = Y[:,:-1], Y[:,1:], Y_vlen-1\n",
    "            \n",
    "            Y_hat, _ = model(X, Y_input, X_vlen, Y_vlen)\n",
    "            l = loss(Y_hat, Y_label, Y_vlen).sum()\n",
    "            l.backward()\n",
    "\n",
    "            with torch.no_grad():\n",
    "                d2l.grad_clipping_nn(model, 5, device)\n",
    "            num_tokens = Y_vlen.sum().item()\n",
    "            optimizer.step()\n",
    "            l_sum += l.sum().item()\n",
    "            num_tokens_sum += num_tokens\n",
    "        if epoch % 50 == 0:\n",
    "            print(\"epoch {0:4d},loss {1:.3f}, time {2:.1f} sec\".format( \n",
    "                  epoch, (l_sum/num_tokens_sum), time.time()-tic))\n",
    "            tic = time.time()"
   ]
  },
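  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "The slicing in `train_ch7` feeds the decoder the ground-truth previous token at every step (often called teacher forcing): the decoder input is the target without its last position (`Y[:, :-1]`), the label is the target without its first position (`Y[:, 1:]`), and the valid length shrinks by one. A small sketch with one invented id sequence in plain Python lists, assuming `<bos>`=1, `<eos>`=2 and `<pad>`=0 as the batch printed earlier suggests:\n",
    "\n",
    "```python\n",
    "# Invented padded target sequence; the ids 1, 2 and 0 are assumed to be <bos>, <eos> and <pad>.\n",
    "Y = [1, 23, 46, 4, 2, 0, 0, 0]\n",
    "Y_valid_len = 5  # <bos>, three tokens, <eos>\n",
    "\n",
    "Y_input = Y[:-1]                   # what the decoder reads\n",
    "Y_label = Y[1:]                    # what the decoder should predict\n",
    "label_valid_len = Y_valid_len - 1  # <bos> is never a prediction target\n",
    "\n",
    "# (input token, token to predict) for every position that counts in the loss.\n",
    "pairs = list(zip(Y_input, Y_label))[:label_valid_len]\n",
    "print(pairs)\n",
    "```\n",
    "\n",
    "The last counted pair is `(4, 2)`: after reading the final real token, the decoder is trained to emit the end-of-sequence id. Positions beyond `label_valid_len` are padding and are masked out by the loss."
   ]
  },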
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "id": "08CBFD23FD6C40E890E96ADFA0FE161F",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch   50,loss 0.093, time 38.2 sec\n",
      "epoch  100,loss 0.046, time 37.9 sec\n",
      "epoch  150,loss 0.032, time 36.8 sec\n",
      "epoch  200,loss 0.027, time 37.5 sec\n",
      "epoch  250,loss 0.026, time 37.8 sec\n",
      "epoch  300,loss 0.025, time 37.3 sec\n"
     ]
    }
   ],
   "source": [
    "embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.0\n",
    "batch_size, num_examples, max_len = 64, 1e3, 10\n",
    "lr, num_epochs, ctx = 0.005, 300, d2l.try_gpu()\n",
    "src_vocab, tgt_vocab, train_iter = d2l.load_data_nmt(\n",
    "    batch_size, max_len,num_examples)\n",
    "encoder = Seq2SeqEncoder(\n",
    "    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
    "decoder = Seq2SeqDecoder(\n",
    "    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
    "model = d2l.EncoderDecoder(encoder, decoder)\n",
    "train_ch7(model, train_iter, lr, num_epochs, ctx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7F4C5D893110439FB07B542D4C161ECD",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "id": "13CC26D191FF43BF8BB79FB14CDF7BAF",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def translate_ch7(model, src_sentence, src_vocab, tgt_vocab, max_len, device):\n",
    "    src_tokens = src_vocab[src_sentence.lower().split(' ')]\n",
    "    src_len = len(src_tokens)\n",
    "    if src_len < max_len:\n",
    "        src_tokens += [src_vocab.pad] * (max_len - src_len)\n",
    "    enc_X = torch.tensor(src_tokens, device=device)\n",
    "    enc_valid_length = torch.tensor([src_len], device=device)\n",
    "    # Use unsqueeze to add the batch dimension.\n",
    "    enc_outputs = model.encoder(enc_X.unsqueeze(dim=0), enc_valid_length)\n",
    "    dec_state = model.decoder.init_state(enc_outputs, enc_valid_length)\n",
    "    dec_X = torch.tensor([tgt_vocab.bos], device=device).unsqueeze(dim=0)\n",
    "    predict_tokens = []\n",
    "    for _ in range(max_len):\n",
    "        Y, dec_state = model.decoder(dec_X, dec_state)\n",
    "        # The token with highest score is used as the next time step input.\n",
    "        dec_X = Y.argmax(dim=2)\n",
    "        py = dec_X.squeeze(dim=0).int().item()\n",
    "        if py == tgt_vocab.eos:\n",
    "            break\n",
    "        predict_tokens.append(py)\n",
    "    return ' '.join(tgt_vocab.to_tokens(predict_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "id": "D978A9B7B69B4E99942AE11491A7273B",
    "jupyter": {},
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Go . => va !\n",
      "Wow ! => <unk> !\n",
      "I'm OK . => ça va .\n",
      "I won ! => j'ai gagné !\n"
     ]
    }
   ],
   "source": [
    "for sentence in ['Go .', 'Wow !', \"I'm OK .\", 'I won !']:\n",
    "    print(sentence + ' => ' + translate_ch7(\n",
    "        model, sentence, src_vocab, tgt_vocab, max_len, ctx))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "B37CC2560DBC41A786C01BB5AB52CBFD",
    "jupyter": {},
    "mdEditEnable": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
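    "# Beam Search\n",
    "Simple greedy search: at each time step, pick the single token with the highest score.\n",
    "\n",
    "![Image Name](https://cdn.kesci.com/upload/image/q5jchqoppn.png?imageView2/0/w/440/h/440)\n",
    "\n",
    "Viterbi algorithm: choose the sentence with the highest overall score (the search space is too large).\n",
    "\n",
    "Beam search: keep only a fixed number of the highest-scoring candidate sequences at each time step.\n",
    "\n",
    "![Image Name](https://cdn.kesci.com/upload/image/q5jcia86z1.png?imageView2/0/w/640/h/640)\n"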
   ]
  },
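  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "A minimal sketch of beam search on a toy problem, not tied to the model trained above. The probability table `TOY_PROBS` and the helper `next_log_probs` are invented stand-ins for one decoder step, and scores are plain sums of log-probabilities without length normalization. Setting `beam_size=1` reduces the procedure to greedy search.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Invented next-token probabilities that depend only on the previous token.\n",
    "TOY_PROBS = {\n",
    "    '<bos>': {'a': 0.6, 'b': 0.4},\n",
    "    'a': {'<eos>': 0.5, 'a': 0.3, 'b': 0.2},\n",
    "    'b': {'<eos>': 0.9, 'a': 0.1},\n",
    "}\n",
    "\n",
    "def next_log_probs(prefix):\n",
    "    # Hypothetical stand-in for one decoder step.\n",
    "    return {tok: math.log(p) for tok, p in TOY_PROBS[prefix[-1]].items()}\n",
    "\n",
    "def beam_search(beam_size, max_len):\n",
    "    beams = [(0.0, ['<bos>'])]  # (cumulative log-probability, tokens)\n",
    "    finished = []\n",
    "    for _ in range(max_len):\n",
    "        candidates = []\n",
    "        for score, seq in beams:\n",
    "            for tok, lp in next_log_probs(seq).items():\n",
    "                candidates.append((score + lp, seq + [tok]))\n",
    "        # Keep only the beam_size best hypotheses.\n",
    "        candidates.sort(key=lambda c: c[0], reverse=True)\n",
    "        beams = []\n",
    "        for score, seq in candidates[:beam_size]:\n",
    "            if seq[-1] == '<eos>':\n",
    "                finished.append((score, seq))\n",
    "            else:\n",
    "                beams.append((score, seq))\n",
    "        if not beams:\n",
    "            break\n",
    "    finished.extend(beams)  # hypotheses cut off at max_len\n",
    "    return max(finished, key=lambda c: c[0])\n",
    "\n",
    "greedy_score, greedy_seq = beam_search(beam_size=1, max_len=5)\n",
    "beam_score, beam_seq = beam_search(beam_size=2, max_len=5)\n",
    "print(greedy_seq, round(math.exp(greedy_score), 2))\n",
    "print(beam_seq, round(math.exp(beam_score), 2))\n",
    "```\n",
    "\n",
    "On this toy table, greedy search commits to `a` first and ends with a sequence of probability 0.30, while a beam of size 2 also keeps `b` alive and finds a sequence of probability 0.36. The locally best first token is not part of the best overall sentence."
   ]
  },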
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D6CFEE200E5D42E2BD3334E612C8A41A",
    "jupyter": {},
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
